#cython: boundscheck=False
#cython: cdivision=True

cimport cython
import numpy as np

cimport numpy as np

from scipy.signal import convolve2d, fftconvolve
from scipy.sparse import csc_matrix, lil_matrix
from scipy.special import expit

from libc.math cimport exp
from libc.stdlib cimport malloc, free
import bottleneck as bn

FLOAT_TYPE = np.float32
INT_TYPE = np.int32
U_INT_TYPE = np.uint32
ctypedef np.float32_t FLOAT_TYPE_t
ctypedef np.int32_t INT_TYPE_t
ctypedef np.uint32_t U_INT_TYPE_t

cdef FLOAT_TYPE_t _inner_sigmoid(FLOAT_TYPE_t x):
    # Scalar logistic function.  For very negative x, exp(-x) overflows to
    # +inf and the quotient correctly underflows to 0.
    cdef FLOAT_TYPE_t denom = 1. + exp(-x)
    return 1. / denom

cdef _sigmoid(const unsigned int n_samples, const unsigned int n_features,
              np.ndarray[FLOAT_TYPE_t, ndim=2] X,
              np.ndarray[FLOAT_TYPE_t, ndim=2] out):
    # Apply the scalar logistic function element-wise, writing into `out`
    # in place; the filled buffer is also returned for convenience.
    cdef unsigned int row, col
    for row in range(n_samples):
        for col in range(n_features):
            out[row, col] = _inner_sigmoid(X[row, col])
    return out

def sigmoid_1(X, out=None):
    """Logistic sigmoid over a 1-D or 2-D array via the typed C loop.

    A 1-D input is promoted to 2-D for the kernel and squeezed back on
    return.  `out` may supply a pre-allocated result buffer.
    """
    squeeze_result = X.ndim == 1
    X = np.asarray(np.atleast_2d(X))
    n_samples, n_features = X.shape
    if out is None:
        out = np.empty_like(X)
    _sigmoid(n_samples, n_features, X, out)
    return np.squeeze(out) if squeeze_result else out

def sigmoid_2(np.ndarray v):
    """Vectorised logistic sigmoid computed with NumPy ufuncs."""
    return 1 / (1 + np.exp(-v))

def sigmoid_3(v):
    """Logistic sigmoid delegating to SciPy's C-implemented `expit`."""
    return expit(v)

# Dispatch aliases used by callers: `standard` is the plain NumPy version,
# `fast` the SciPy expit version.  The typed-loop sigmoid_1 is not aliased.
sigmoid_standard = sigmoid_2
sigmoid_fast = sigmoid_3

cdef FLOAT_TYPE_t _inner_tanh(FLOAT_TYPE_t x):
    # Hyperbolic tangent: tanh(x) = (1 - e^(-2x)) / (1 + e^(-2x)).
    # BUG FIX: the previous formula (1 - e^-x) / (1 + e^-x) computes
    # tanh(x/2), not tanh(x).  Branching keeps the exponent non-positive so
    # exp() cannot overflow (which previously produced NaN for large -x).
    cdef FLOAT_TYPE_t t
    if x >= 0:
        t = exp(-2. * x)
        return (1. - t) / (1. + t)
    t = exp(2. * x)
    return (t - 1.) / (t + 1.)

cdef _tanh(const unsigned int n_samples, const unsigned int n_features,
              np.ndarray[FLOAT_TYPE_t, ndim=2] X,
              np.ndarray[FLOAT_TYPE_t, ndim=2] out):
    # Apply the scalar tanh helper element-wise, writing into `out` in place;
    # the filled buffer is also returned.
    cdef unsigned int row, col
    for row in range(n_samples):
        for col in range(n_features):
            out[row, col] = _inner_tanh(X[row, col])
    return out

def tanh_1(X, out=None):
    """Hyperbolic tangent over a 1-D or 2-D array via the typed C loop.

    A 1-D input is promoted to 2-D for the kernel and squeezed back on
    return.  `out` may supply a pre-allocated result buffer.
    """
    squeeze_result = X.ndim == 1
    X = np.asarray(np.atleast_2d(X))
    n_samples, n_features = X.shape
    if out is None:
        out = np.empty_like(X)
    _tanh(n_samples, n_features, X, out)
    return np.squeeze(out) if squeeze_result else out

def tanh_2(v):
    """NumPy tanh clipped just inside (-1, 1), keeping downstream divisions
    and logs of (1 - y) finite."""
    raw = np.tanh(v)
    return np.clip(raw, -0.99999999, 0.99999999)

# Both aliases point at the clipped NumPy tanh; the typed-loop variant
# (tanh_1) is not aliased here.
tanh_fast = tanh_2
tanh_standard = tanh_2

def softmax_1(v):
    """Softmax of a vector.

    FIX: subtract max(v) before exponentiating.  This leaves the result
    mathematically unchanged but prevents exp() overflow (NaN output) for
    large inputs, which the unshifted version suffered from.
    """
    e = np.exp(v - np.max(v))
    return e / np.sum(e)

cpdef softmax_2(np.ndarray[FLOAT_TYPE_t, ndim=1] v):
    """Softmax with the outputs clipped into [1e-8, 1 - 1e-8]."""
    cdef np.ndarray[FLOAT_TYPE_t, ndim=1] exps = np.exp(v)
    probs = exps / np.sum(exps)
    return np.clip(probs, 0.00000001, 0.99999999)

# NOTE(review): unlike the other alias pairs in this file, `fast` points at
# the plain NumPy softmax (softmax_1) and `standard` at the Cython one
# (softmax_2); the two also differ in behavior (softmax_2 clips its output,
# softmax_1 does not) -- confirm this is intended.
softmax_standard = softmax_2
softmax_fast = softmax_1

cdef _inner_input_layer_backward(FLOAT_TYPE_t *grad_W, FLOAT_TYPE_t *delta, \
                                 const unsigned int wordvec_dim, const unsigned int vocab_size, \
                                 const FLOAT_TYPE_t decay, FLOAT_TYPE_t *W, U_INT_TYPE_t*indexes, \
                                 const unsigned int indexes_len):
    # Scatter-add the embedding-lookup gradient into grad_W.
    # grad_W: flat row-major (wordvec_dim, vocab_size) buffer (accumulated into).
    # delta:  flat row-major (wordvec_dim, indexes_len) upstream gradient.
    # indexes: vocabulary ids of this sample's words; duplicates accumulate.
    # `decay`, `W` and `index` are only used by the weight-decay term below,
    # which is currently disabled (kept as a string literal).
    cdef unsigned int i, j, index = 0
    for i in range(wordvec_dim):
        for j in range(indexes_len):
            grad_W[i * vocab_size + indexes[j]] += delta[i * indexes_len + j]
    # comment these lines for speed consideration
    '''for i in range(wordvec_dim):
        for j in range(vocab_size):
            grad_W[index] += decay * W[index]
            index += 1'''

def input_layer_backward_1(np.ndarray[FLOAT_TYPE_t, ndim=3] delta not None, const FLOAT_TYPE_t decay, \
                           np.ndarray[FLOAT_TYPE_t, ndim=2] W not None, \
                           np.ndarray[U_INT_TYPE_t, ndim=1] indexes not None):
    """Gradient of the embedding (input) layer w.r.t. W, via the C kernel.

    delta:   (1, wordvec_dim, len(indexes)) upstream gradient.
    W:       (wordvec_dim, vocab_size) embedding matrix.
    indexes: vocabulary ids looked up in the forward pass.
    Returns grad_W with the same shape as W.  `decay` is forwarded but the
    weight-decay term is currently disabled inside the kernel.
    NOTE(review): raw pointers to the first elements are passed down, so all
    three arrays are assumed C-contiguous -- confirm at the call sites.
    """
    cdef np.ndarray[FLOAT_TYPE_t, ndim=2] grad_W = np.zeros((W.shape[0], W.shape[1]), dtype=FLOAT_TYPE)
    _inner_input_layer_backward(&grad_W[0, 0], &delta[0, 0, 0], W.shape[0], \
                                W.shape[1], decay, &W[0, 0], &indexes[0], len(indexes))
    return grad_W

def input_layer_backward_2(delta, decay, W, indexes):
    """Scatter-add embedding gradients column-by-column (NumPy reference).

    delta has shape (1, wordvec_dim, len(indexes)); repeated vocabulary ids
    accumulate into the same column of grad_W.
    """
    grad_W = np.zeros(W.shape, dtype=FLOAT_TYPE)
    for pos, vocab_id in enumerate(indexes):
        grad_W[:, vocab_id] += delta[0, :, pos]
    # weight decay (decay * W) intentionally left out for speed
    return grad_W

def input_layer_backward_3(delta, decay, W, indexes):
    """Embedding-gradient scatter that pre-accumulates duplicate indexes.

    delta has shape (1, wordvec_dim, len(indexes)).  Gradients of repeated
    vocabulary ids are summed in a dict first, then written into grad_W once
    per distinct column.  `decay`/weight decay is intentionally skipped.
    """
    index_map = {}
    for i in range(len(indexes)):
        index = indexes[i]
        if index in index_map:
            index_map[index] += delta[0, :, i]
        else:
            # FIX: copy the column.  Storing the raw view and then applying
            # `+=` to it wrote the accumulated sum back into the caller's
            # `delta` array in place.
            index_map[index] = delta[0, :, i].copy()
    grad_W = np.zeros(W.shape, dtype=FLOAT_TYPE)
    # FIX: dict.iteritems() is Python 2 only; items() works on both 2 and 3.
    for k, v in index_map.items():
        grad_W[:, k] = v
    # weight decay (decay * W) intentionally skipped for speed
    return grad_W

def input_layer_backward_special(delta, decay, W, indexes, grad):
    """Accumulate embedding gradients into the caller-supplied `grad` dict.

    delta has shape (1, wordvec_dim, len(indexes)).  Each distinct vocabulary
    id contributes the sum of its delta columns to grad[id] (created on first
    use, accumulated in place afterwards).  `decay` and `W` are unused here.
    """
    index_map = {}
    cdef unsigned int i
    for i in range(<unsigned int> len(indexes)):
        index = indexes[i]
        if index in index_map:
            index_map[index] += delta[0, :, i]
        else:
            # FIX: copy the column.  Storing the raw view and then applying
            # `+=` to it wrote the accumulated sum back into the caller's
            # `delta` array in place.
            index_map[index] = delta[0, :, i].copy()
    # FIX: dict.iteritems() is Python 2 only; items() works on both 2 and 3.
    for k, v in index_map.items():
        if k in grad:
            grad[k] += v
        else:
            grad[k] = v

# `fast` uses the pointer-based C kernel; `standard` the NumPy loop.
input_layer_backward_standard = input_layer_backward_2
input_layer_backward_fast = input_layer_backward_1

def _inner_wide_convolution(input, input_rows, input_cols, filter, filter_cols):
    # Full-mode ("wide") convolution wrapper around the typed kernel.
    # NOTE(review): input_rows, input_cols and filter_cols are accepted but
    # ignored -- shapes are re-read from the arrays themselves; the extra
    # parameters appear to be kept only for call-site compatibility.
    return _inner_convolve2d(input, input.shape[0], input.shape[1], filter, filter.shape[0], filter.shape[1], 1)

@cython.boundscheck(False)
cdef _inner_convolve2d(np.ndarray[FLOAT_TYPE_t, ndim=2] input, const unsigned int input_rows,
                       const unsigned int input_cols,
                       np.ndarray[FLOAT_TYPE_t, ndim=2] filter, const unsigned int filter_rows,
                       const unsigned int filter_cols,
                       const int mode):
    # Direct (nested-loop) 2-D convolution with the kernel flipped in both
    # axes, matching scipy.signal.convolve2d semantics.
    # mode == 1 -> 'full' output of shape (rows+frows-1, cols+fcols-1),
    #              with the input zero-padded by (filter size - 1) per side;
    # mode == 0 -> 'valid' output of shape (rows-frows+1, cols-fcols+1);
    # any other mode raises ValueError.
    cdef np.ndarray[FLOAT_TYPE_t, ndim=2] input_padded, output
    cdef unsigned int input_rows_padded, input_cols_padded, i, j, m, n, output_rows, output_cols = 0
    if mode == 1: # full mode
        output_rows = input_rows + filter_rows - 1
        output_cols = input_cols + filter_cols - 1
        input_rows_padded = input_rows + 2 * filter_rows - 2
        input_cols_padded = input_cols + 2 * filter_cols - 2
        # Build the zero-padded copy by hand (instead of np.pad) so the fill
        # loop stays inside typed C code.
        input_padded = np.empty((input_rows_padded, input_cols_padded), dtype=FLOAT_TYPE)
        for i in range(input_rows_padded):
            for j in range(input_cols_padded):
                if i < filter_rows - 1 or i >= input_rows + filter_rows - 1:
                    input_padded[i, j] = 0.0
                elif j < filter_cols - 1 or j >= input_cols + filter_cols - 1:
                    input_padded[i, j] = 0.0
                else:
                    input_padded[i, j] = input[<unsigned int> i - filter_rows + 1, <unsigned int> j - filter_cols + 1]
    elif mode == 0: # valid mode
        output_rows = input_rows - filter_rows + 1
        output_cols = input_cols - filter_cols + 1
        input_rows_padded = input_rows
        input_cols_padded = input_cols
        input_padded = input  # no copy needed: the valid window never leaves bounds
    else:
        raise ValueError
    cdef FLOAT_TYPE_t wt
    output = np.empty((output_rows, output_cols), dtype=FLOAT_TYPE)
    for i in range(output_rows):
        for j in range(output_cols):
            wt = 0.0
            # Accumulate with the kernel rotated 180 degrees (true convolution,
            # not cross-correlation).
            for m in range(filter_rows):
                for n in range(filter_cols):
                    wt += input_padded[<unsigned int> i + m, <unsigned int> j + n] * \
                          filter[<unsigned int> filter_rows - m - 1, <unsigned int> filter_cols - n - 1]
            output[i, j] = wt
    return output

def wide_convolution2d_1(input, filter):
    """Full ('wide') 2-D convolution of `input` with a 1-D or 2-D filter,
    using the typed convolution kernel."""
    kernel = np.atleast_2d(filter)
    return _inner_wide_convolution(input, input.shape[0], input.shape[1], kernel, max(kernel.shape))

def wide_convolution2d_2(input, filter):
    """Reference wide (full-mode) convolution via scipy.signal.convolve2d."""
    kernel = np.atleast_2d(filter)
    return convolve2d(input, kernel)

def wide_convolution2d_3(input, filter):
    """Wide (full-mode) convolution calling the typed kernel directly."""
    kernel = np.atleast_2d(filter)
    return _inner_convolve2d(input, input.shape[0], input.shape[1], kernel, kernel.shape[0], kernel.shape[1], 1)

# `standard` defers to SciPy; `fast` uses the hand-written typed kernel.
wide_convolution2d_standard = wide_convolution2d_2
wide_convolution2d_fast = wide_convolution2d_1

def convolve2d_1(input, filter, mode='full'):
    """2-D convolution via the typed kernel.

    mode 'full' selects full output; any other value selects valid mode.
    """
    flag = 1 if mode == 'full' else 0
    return _inner_convolve2d(input, input.shape[0], input.shape[1], filter, filter.shape[0], filter.shape[1], flag)

# `standard` is scipy.signal.convolve2d itself; `fast` wraps the typed kernel.
convolve2d_standard = convolve2d
convolve2d_fast = convolve2d_1

def wide_convolution_layer_forward_1(np.ndarray[FLOAT_TYPE_t, ndim=3] input, const unsigned int window_size, \
                                     const unsigned int n_filters, np.ndarray[FLOAT_TYPE_t, ndim=4] W, \
                                     np.ndarray[FLOAT_TYPE_t, ndim=1] b):
    """Forward pass of the wide (full-mode) convolution layer.

    input: (n_feature_maps, n_feature_rows, n_feature_cols)
    W:     (n_feature_maps, n_filters, 1, window_size)
    b:     (n_filters,)
    Returns (n_filters, n_feature_rows, n_feature_cols + window_size - 1).
    """
    cdef unsigned int n_feature_maps = input.shape[0]
    cdef unsigned int n_feature_rows = input.shape[1]
    cdef unsigned int n_feature_cols = input.shape[2]
    cdef np.ndarray[FLOAT_TYPE_t, ndim=3] output = np.zeros(
        (n_filters, n_feature_rows, n_feature_cols + window_size - 1), dtype=FLOAT_TYPE)
    cdef unsigned int filt, fmap
    for filt in range(n_filters):
        # Sum full-mode convolutions over all input maps, then add the bias;
        # filters are independent, so folding the bias into this loop is
        # equivalent to a separate pass.
        for fmap in range(n_feature_maps):
            output[filt, :, :] += convolve2d_fast(input[fmap, :, :], np.rot90(W[fmap, filt, :, :], 2), mode='full')
        output[filt, :, :] += b[filt]
    return output

def wide_convolution_layer_forward_2(input, window_size, n_filters, W, b):
    """Forward pass of the wide convolution layer (SciPy reference).

    input: (n_feature_maps, n_feature_rows, n_feature_cols)
    W:     (n_feature_maps, n_filters, 1, window_size); b: (n_filters,)
    Returns (n_filters, n_feature_rows, n_feature_cols + window_size - 1).
    """
    n_feature_maps = input.shape[0]
    out_shape = (n_filters, input.shape[1], input.shape[2] + window_size - 1)
    output = np.zeros(out_shape, dtype=FLOAT_TYPE)
    for filt in range(n_filters):
        # Filters are independent, so adding the bias inside this loop is
        # equivalent to a separate pass over filters.
        for fmap in range(n_feature_maps):
            output[filt, :, :] += convolve2d(input[fmap, :, :], np.rot90(W[fmap, filt, :, :], 2), mode='full')
        output[filt] += b[filt]
    return output

def wide_convolution_layer_forward_3(np.ndarray[FLOAT_TYPE_t, ndim=3] input, const unsigned int window_size, \
                                     const unsigned int n_filters, np.ndarray[FLOAT_TYPE_t, ndim=4] W, \
                                     np.ndarray[FLOAT_TYPE_t, ndim=1] b):
    """Forward pass of the wide convolution layer via wide_convolution2d_fast.

    input: (n_feature_maps, n_feature_rows, n_feature_cols)
    W:     (n_feature_maps, n_filters, 1, window_size); b: (n_filters,)
    Returns (n_filters, n_feature_rows, n_feature_cols + window_size - 1).
    """
    cdef unsigned int n_feature_maps = input.shape[0]
    cdef unsigned int n_feature_rows = input.shape[1]
    cdef unsigned int n_feature_cols = input.shape[2]
    cdef np.ndarray[FLOAT_TYPE_t, ndim=3] output = np.zeros(
        (n_filters, n_feature_rows, n_feature_cols + window_size - 1), dtype=FLOAT_TYPE)
    cdef unsigned int filt, fmap
    for filt in range(n_filters):
        # Filters are independent, so the bias add can share this loop.
        for fmap in range(n_feature_maps):
            output[filt, :, :] += wide_convolution2d_fast(input[fmap, :, :], np.rot90(W[fmap, filt, :, :], 2))
        output[filt, :, :] += b[filt]
    return output

# `standard` is the SciPy-backed version; `fast` the typed-kernel version.
wide_convolution_layer_forward_standard = wide_convolution_layer_forward_2
wide_convolution_layer_forward_fast = wide_convolution_layer_forward_1

@cython.boundscheck(False)
def wide_convolution_layer_backward_1(np.ndarray[FLOAT_TYPE_t, ndim=3] input_images, \
                                      np.ndarray[FLOAT_TYPE_t, ndim=4] W, \
                                      np.ndarray[FLOAT_TYPE_t, ndim=1] b, \
                                      np.ndarray[FLOAT_TYPE_t, ndim=3] grad, \
                                      const FLOAT_TYPE_t decay, object back_linear=True):
    """Backward pass of the wide convolution layer.

    input_images: (n_feature_maps, n_feature_rows, n_feature_cols) layer inputs.
    W:            (n_feature_maps, n_filters, ...) filter bank.
    b:            per-filter biases.
    grad:         upstream gradient, one map per filter.
    decay:        L2 weight-decay coefficient folded into grad_W.
    back_linear:  when falsy, scale the propagated gradient by the sigmoid
                  derivative x * (1 - x) of the layer input.
    Returns (back_grad, grad_W, grad_b).
    """
    cdef unsigned int n_feature_maps, n_filters, n_feature_rows, n_feature_cols, i, j
    n_feature_maps = W.shape[0]
    n_filters = W.shape[1]
    n_feature_rows = input_images.shape[1]
    n_feature_cols = input_images.shape[2]
    cdef np.ndarray[FLOAT_TYPE_t, ndim=3] back_grad = np.zeros((n_feature_maps, n_feature_rows, n_feature_cols), dtype=FLOAT_TYPE)
    cdef np.ndarray[FLOAT_TYPE_t, ndim=4] grad_W = np.empty((W.shape[0], W.shape[1], W.shape[2], W.shape[3]), dtype=FLOAT_TYPE)
    cdef np.ndarray[FLOAT_TYPE_t, ndim=1] grad_b = np.empty((b.shape[0],), dtype=FLOAT_TYPE)
    # Gradient w.r.t. the layer input: valid-mode convolution of the upstream
    # gradient with each (un-rotated) filter, summed over filters.
    for i in range(n_feature_maps):
        for j in range(n_filters):
            back_grad[i, :, :] += convolve2d_fast(grad[j, :, :], W[i, j, :, :], mode='valid')
    for i in range(n_feature_maps):
        if not back_linear:
            back_grad[i, :, :] = back_grad[i, :, :] * input_images[i, :, :] * (1 - input_images[i, :, :])
    # Gradient w.r.t. the filters: valid-mode convolution of the rotated
    # upstream gradient with the forward input.
    for i in range(n_feature_maps):
        for j in range(n_filters):
            grad_W[i, j, :, :] = convolve2d_fast(np.rot90(grad[j, :, :], 2), input_images[i, :, :], mode='valid')
    grad_W += decay * W
    # Bias gradient: total upstream gradient per filter.
    # NOTE(review): assumes b.shape[0] == n_filters -- confirm at call sites.
    for i in range(n_filters):
        grad_b[i] = np.sum(grad[i,:,:])
    return back_grad, grad_W, grad_b

def wide_convolution_layer_backward_2(input_images, W, b, grad, decay, back_linear=True):
    """Backward pass of the wide convolution layer (NumPy/SciPy reference).

    Returns (back_grad, grad_W, grad_b); see the layer forward pass for the
    array shapes.  When back_linear is falsy the propagated gradient is
    scaled by the sigmoid derivative x * (1 - x) of the layer input.
    """
    n_feature_maps, n_filters = W.shape[0], W.shape[1]
    back_grad = np.zeros(input_images.shape, dtype=FLOAT_TYPE)
    grad_W = np.empty(W.shape, dtype=FLOAT_TYPE)
    grad_b = np.empty(b.shape, dtype=FLOAT_TYPE)
    for fmap in range(n_feature_maps):
        # Gradient w.r.t. the input: valid convolution with each filter.
        for filt in range(n_filters):
            back_grad[fmap, :, :] += convolve2d_fast(grad[filt, :, :], W[fmap, filt, :, :], mode='valid')
        if not back_linear:
            back_grad[fmap, :, :] *= input_images[fmap, :, :] * (1 - input_images[fmap, :, :])
        # Gradient w.r.t. the filters: rotated upstream gradient against input.
        for filt in range(n_filters):
            grad_W[fmap, filt, :, :] = convolve2d_fast(np.rot90(grad[filt, :, :], 2), input_images[fmap, :, :], mode='valid')
    grad_W += decay * W
    for filt in range(n_filters):
        grad_b[filt] = np.sum(grad[filt])
    return back_grad, grad_W, grad_b

# `standard` is the plain-Python version; `fast` the typed version.
wide_convolution_layer_backward_standard = wide_convolution_layer_backward_2
wide_convolution_layer_backward_fast = wide_convolution_layer_backward_1


def k_max_pooling_image_1(np.ndarray[FLOAT_TYPE_t, ndim=2] input_image, \
                          const unsigned int k, np.ndarray[FLOAT_TYPE_t, ndim=1] b):
    """Per-row k-max pooling with a per-row bias (assumes k >= 1).

    Returns (pooled image of shape (rows, k), sorted column indexes of each
    row's k largest entries -- sorting keeps their original order).
    """
    cdef unsigned int rows = input_image.shape[0]
    # FIX: bottleneck.argpartsort was removed in bottleneck >= 1.0; use
    # numpy.argpartition, which yields the same k-largest index set.
    cdef np.ndarray[U_INT_TYPE_t, ndim=2] kmax_index = np.sort(
        np.argpartition(-input_image, k - 1, axis=1)[:, :k], axis=1).astype(U_INT_TYPE)
    cdef np.ndarray[FLOAT_TYPE_t, ndim=2] new_image = np.empty((rows, k), dtype=FLOAT_TYPE)
    cdef unsigned int i, j
    for i in range(rows):
        for j in range(k):
            new_image[i, j] = input_image[i, <unsigned int> kmax_index[i, j]] + b[i]
    return new_image, kmax_index

def k_max_pooling_image_2(input_image, k, b):
    """Per-row k-max pooling (NumPy reference).

    b has shape (input_image.shape[0],).  Returns the pooled image and the
    sorted column indexes of each row's k largest values.
    """
    kmax_index = np.sort(np.argsort(input_image)[:, -k:])
    rows = input_image.shape[0]
    new_image = np.empty((rows, k), dtype=FLOAT_TYPE)
    for row in range(rows):
        new_image[row, :] = input_image[row, kmax_index[row, :]] + b[row]
    return new_image, kmax_index

def k_max_pooling_image_3(np.ndarray[FLOAT_TYPE_t, ndim=2] input_image, const int k, \
                          np.ndarray[FLOAT_TYPE_t, ndim=1] b):
    """Per-row k-max pooling via a full argsort, typed copy loop."""
    cdef unsigned int n_rows = input_image.shape[0]
    cdef np.ndarray[U_INT_TYPE_t, ndim=2] kmax_index = np.sort(np.argsort(input_image)[:, -k:]).astype(U_INT_TYPE)
    cdef np.ndarray[FLOAT_TYPE_t, ndim=2] pooled = np.empty((n_rows, k), dtype=FLOAT_TYPE)
    cdef unsigned int row, slot
    for row in range(n_rows):
        for slot in range(k):
            pooled[row, slot] = input_image[row, <unsigned int> kmax_index[row, slot]] + b[row]
    return pooled, kmax_index

# `standard` uses a full argsort; `fast` a partial partition.
k_max_pooling_image_standard = k_max_pooling_image_2
k_max_pooling_image_fast = k_max_pooling_image_1

def k_max_pooling_backward_1(np.ndarray[FLOAT_TYPE_t, ndim=3] grad, \
                             np.ndarray[FLOAT_TYPE_t, ndim=3] input_images, \
                             np.ndarray[U_INT_TYPE_t, ndim=3] kmax_index, \
                             object back_linear=True):
    """Route gradients back through k-max pooling (typed loops).

    grad and kmax_index have shape (n_feature_maps, n_feature_rows, k);
    input_images has shape (n_feature_maps, n_feature_rows, n_feature_cols).
    Positions not selected by the pooling get zero gradient.  When
    back_linear is falsy the result is scaled by the sigmoid derivative of
    the layer input.  Returns (back_grad, grad_b) with grad_b summing grad
    over the k axis.
    """
    cdef unsigned int n_maps = input_images.shape[0]
    cdef unsigned int n_rows = input_images.shape[1]
    cdef unsigned int n_cols = input_images.shape[2]
    cdef unsigned int k = grad.shape[2]
    cdef np.ndarray[FLOAT_TYPE_t, ndim=3] back_grad = np.zeros((n_maps, n_rows, n_cols), dtype=FLOAT_TYPE)
    cdef np.ndarray[FLOAT_TYPE_t, ndim=2] grad_b = np.zeros((n_maps, n_rows), dtype=FLOAT_TYPE)
    cdef unsigned int fmap, row, slot
    for fmap in range(n_maps):
        for row in range(n_rows):
            for slot in range(k):
                back_grad[fmap, row, <unsigned int> kmax_index[fmap, row, slot]] = grad[fmap, row, slot]
                grad_b[fmap, row] += grad[fmap, row, slot]
        if not back_linear:
            for row in range(n_rows):
                for slot in range(n_cols):
                    back_grad[fmap, row, slot] = back_grad[fmap, row, slot] * input_images[fmap, row, slot] * (1 - input_images[fmap, row, slot])
    return back_grad, grad_b

def k_max_pooling_backward_2(grad, input_images, kmax_index, back_linear=True):
    """Route gradients back through k-max pooling (NumPy reference).

    input_images: (n_feature_maps, n_feature_rows, n_feature_cols);
    kmax_index:   (n_feature_maps, n_feature_rows, k).
    Unselected positions keep zero gradient; when back_linear is falsy the
    result is scaled by the sigmoid derivative of the input.
    """
    n_maps, n_rows, n_cols = input_images.shape
    back_grad = np.zeros(input_images.shape, dtype=FLOAT_TYPE)
    grad_b = np.zeros((n_maps, n_rows), dtype=FLOAT_TYPE)
    for fmap in range(n_maps):
        for row in range(n_rows):
            back_grad[fmap, row, kmax_index[fmap, row, :]] = grad[fmap, row, :]
            grad_b[fmap, row] = np.sum(grad[fmap, row, :])
        if not back_linear:
            back_grad[fmap, :, :] = back_grad[fmap, :, :] * input_images[fmap, :, :] * (1 - input_images[fmap, :, :])
    return back_grad, grad_b

# `standard` is the NumPy version; `fast` the typed-loop version.
k_max_pooling_backward_standard = k_max_pooling_backward_2
k_max_pooling_backward_fast = k_max_pooling_backward_1

def folding_image_1(np.ndarray[FLOAT_TYPE_t, ndim=2] input_image):
    """Fold the image by summing adjacent row pairs (rows 2i and 2i+1).

    Returns an array of shape (rows // 2, cols).
    NOTE(review): an odd row count would make the inner loop read row
    `rows` out of bounds (boundscheck is off) -- callers must pass an even
    number of rows.
    """
    cdef unsigned int rows = input_image.shape[0]
    cdef unsigned int cols = input_image.shape[1]
    # FIX: `j` was previously untyped (a Python-object loop variable) while
    # `i` was declared but never used; typing `j` keeps both loops in C.
    cdef unsigned int j, k
    cdef np.ndarray[FLOAT_TYPE_t, ndim=2] output_image = np.empty((rows/2, cols), dtype=FLOAT_TYPE)
    for j in range(0, cols):
        for k in range(0, rows, 2):
            output_image[<unsigned int> k/2, j] = input_image[k,j] +input_image[<unsigned int> k+1,j]
    return output_image

def folding_image_2(input_image):
    """Fold the image by summing adjacent row pairs (pure-Python reference).

    Returns an array of shape (rows // 2, cols); assumes an even row count.
    FIX: use floor division (//) -- under Python 3 the old `/` produced a
    float shape for np.empty and a float row index, both of which raise.
    """
    output_image = np.empty((input_image.shape[0] // 2, input_image.shape[1]), dtype=FLOAT_TYPE)
    for j in range(input_image.shape[1]):
        for k in range(0, input_image.shape[0], 2):
            output_image[k // 2, j] = input_image[k, j] + input_image[k + 1, j]
    return output_image

# `standard` is the plain-Python version; `fast` the typed version.
folding_image_standard = folding_image_2
folding_image_fast = folding_image_1

def dropout_sample_1d(const unsigned int size, const FLOAT_TYPE_t threshold):
    """Return the list of positions whose uniform draw falls below
    `threshold` -- the slots selected by dropout out of `size` candidates."""
    cdef np.ndarray[FLOAT_TYPE_t, ndim=1] prob = np.random.rand(size).astype(FLOAT_TYPE)
    # flatnonzero yields indexes in ascending order, matching the original
    # append loop.
    return np.flatnonzero(prob < threshold).tolist()

def dropout_sample_2d(const unsigned int rows, const unsigned int cols, const FLOAT_TYPE_t threshold):
    """Return parallel (row, col) index lists of the grid cells whose uniform
    draw falls below `threshold` -- the positions selected by dropout."""
    cdef np.ndarray[FLOAT_TYPE_t, ndim=2] prob = np.random.rand(rows*cols).astype(FLOAT_TYPE).reshape(rows, cols)
    # np.nonzero scans in row-major order, matching the original nested loop.
    hit_rows, hit_cols = np.nonzero(prob < threshold)
    return hit_rows.tolist(), hit_cols.tolist()
